{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ULdrhOaVbsdO"
      },
      "source": [
        "# Acme: Quickstart\n",
        "# \u003cdiv align=\"left\"\u003e[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/acme/examples/quickstart.ipynb)\u003c/div\u003e\n",
        "\n",
        "This is a quick guide to installing Acme and a very simple example of running a D4PG agent."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xaJxoatMhJ71"
      },
      "source": [
        "## Installation\n",
        "\n",
        "In the first few cells we'll start by installing all of the necessary dependencies (and a few optional ones)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "KH3O0zcXUeun"
      },
      "outputs": [],
      "source": [
        "#@title Install necessary dependencies.\n",
        "\n",
        "!pip install dm-acme\n",
        "!pip install dm-acme[reverb]\n",
        "!pip install dm-acme[tf]\n",
        "!pip install dm-acme[envs]\n",
        "\n",
        "from IPython.display import clear_output\n",
        "clear_output()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VEEj3Qw60y73"
      },
      "source": [
        "### Install dm_control\n",
        "\n",
        "The next cell will install environments provided by `dm_control` _if_ you have an institutional MuJoCo license. This is not necessary, but without this you won't be able to use the `dm_cartpole` environment below and can instead follow this colab using `gym` environments. To do so simply expand the following cell, paste in your license file, and run the cell.\n",
        "\n",
        "Alternatively, Colab supports using a Jupyter kernel on your local machine which can be accomplished by following the guidelines here: https://research.google.com/colaboratory/local-runtimes.html. This will allow you to install `dm_control` by following instructions in https://github.com/deepmind/dm_control and using a personal MuJoCo license.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "IbZxYDxzoz5R"
      },
      "outputs": [],
      "source": [
        "#@title Add your License\n",
        "#@test {\"skip\": true}\n",
        "mjkey = \"\"\"\n",
        "\"\"\".strip()\n",
        "\n",
        "mujoco_dir = \"$HOME/.mujoco\"\n",
        "\n",
        "# Install OpenGL dependencies\n",
        "!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
        "  libgl1-mesa-glx libosmesa6 libglew2.0\n",
        "\n",
        "# Get MuJoCo binaries\n",
        "!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
        "!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
        "\n",
        "# Copy over MuJoCo license\n",
        "!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
        "\n",
        "# Install dm_control\n",
        "!pip install dm_control\n",
        "\n",
        "# Configure dm_control to use the OSMesa rendering backend\n",
        "%env MUJOCO_GL=osmesa\n",
        "\n",
        "# Check that the installation succeeded\n",
        "try:\n",
        "  from dm_control import suite\n",
        "  env = suite.load('cartpole', 'swingup')\n",
        "  pixels = env.physics.render()\n",
        "except Exception as e:\n",
        "  raise e from RuntimeError(\n",
        "      'Something went wrong during installation. Check the shell output above '\n",
        "      'for more information.')\n",
        "else:\n",
        "  from IPython.display import clear_output\n",
        "  clear_output()\n",
        "  del suite, env, pixels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "c-H2d6UZi7Sf"
      },
      "source": [
        "## Import Modules\n",
        "\n",
        "Now we can import all the relevant modules."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "HJ74Id-8MERq"
      },
      "outputs": [],
      "source": [
        "#python3\n",
        "\n",
        "%%capture\n",
        "import copy\n",
        "import IPython\n",
        "\n",
        "\n",
        "from acme import environment_loop\n",
        "from acme import networks\n",
        "from acme.adders import reverb as adders\n",
        "from acme.agents import actors_tf2 as actors\n",
        "from acme.datasets import reverb as datasets\n",
        "from acme.wrappers import gym_wrapper\n",
        "from acme import specs\n",
        "from acme import wrappers\n",
        "from acme.agents import d4pg\n",
        "from acme.agents import agent\n",
        "from acme.utils import tf2_utils\n",
        "from acme.utils import loggers\n",
        "\n",
        "import gym\n",
        "import dm_env\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import reverb\n",
        "import sonnet as snt\n",
        "import tensorflow as tf\n",
        "\n",
        "# Import dm_control if it exists.\n",
        "try:\n",
        "  from dm_control import suite\n",
        "except OSError, ModuleNotFoundError:\n",
        "  pass\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "I6KuVGSk4uc9"
      },
      "source": [
        "## Load an environment\n",
        "\n",
        "We can now load an environment. In what follows we'll create an environment and grab the environment's specifications."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "4PVlHtGF5yzt"
      },
      "outputs": [],
      "source": [
        "environment_name = 'gym_mountaincar'  # @param ['dm_cartpole', 'gym_mountaincar']\n",
        "\n",
        "if 'dm_cartpole' in environment_name:\n",
        "  environment = suite.load('cartpole', 'balance')\n",
        "  environment = wrappers.SinglePrecisionWrapper(environment)\n",
        "  def render(env):\n",
        "    return env._physics.render(camera_id=0)  #pylint: disable=protected-access\n",
        "\n",
        "elif 'gym_mountaincar' in environment_name:\n",
        "  environment = gym_wrapper.GymWrapper(gym.make('MountainCarContinuous-v0'))\n",
        "  environment = wrappers.SinglePrecisionWrapper(environment)\n",
        "  def render(env):\n",
        "    return env.environment.render(mode='rgb_array')\n",
        "else:\n",
        "  raise ValueError('Unknown environment: {}.'.format(environment_name))\n",
        "\n",
        "# Grab the spec of the environment.\n",
        "environment_spec = specs.make_environment_spec(environment)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BukOfOsmtSQn"
      },
      "source": [
        " ## Create a D4PG agent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "3Jcjk1w6oHVX"
      },
      "outputs": [],
      "source": [
        "#@title Build agent networks\n",
        "\n",
        "# Get total number of action dimensions from action spec.\n",
        "num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)\n",
        "\n",
        "# Create the shared observation network; here simply a state-less operation.\n",
        "observation_network = tf2_utils.batch_concat\n",
        "\n",
        "# Create the deterministic policy network.\n",
        "policy_network = snt.Sequential([\n",
        "    networks.LayerNormMLP((256, 256, 256), activate_final=True),\n",
        "    networks.NearZeroInitializedLinear(num_dimensions),\n",
        "    networks.TanhToSpec(environment_spec.actions),\n",
        "])\n",
        "\n",
        "# Create the distributional critic network.\n",
        "critic_network = snt.Sequential([\n",
        "    # The multiplexer concatenates the observations/actions.\n",
        "    networks.CriticMultiplexer(),\n",
        "    networks.LayerNormMLP((512, 512, 256), activate_final=True),\n",
        "    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),\n",
        "])\n"
      ]
    },
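    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `DiscreteValuedHead` at the end of the critic outputs logits over a fixed grid of return values rather than a single scalar. As a rough, dependency-free illustration (not Acme's implementation), the sketch below builds the evenly spaced support implied by `vmin=-150.`, `vmax=150.`, `num_atoms=51` and converts a set of logits into an expected return. The helper names `make_support`, `softmax`, and `expected_value` are hypothetical and exist only for this example.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def make_support(vmin, vmax, num_atoms):\n",
        "  # Evenly spaced atoms from vmin to vmax, inclusive.\n",
        "  step = (vmax - vmin) / (num_atoms - 1)\n",
        "  return [vmin + i * step for i in range(num_atoms)]\n",
        "\n",
        "def softmax(logits):\n",
        "  # Numerically stable softmax over a list of floats.\n",
        "  peak = max(logits)\n",
        "  exps = [math.exp(x - peak) for x in logits]\n",
        "  total = sum(exps)\n",
        "  return [e / total for e in exps]\n",
        "\n",
        "def expected_value(logits, support):\n",
        "  # Mean of the categorical distribution defined by the logits.\n",
        "  probs = softmax(logits)\n",
        "  return sum(p * v for p, v in zip(probs, support))\n",
        "\n",
        "support = make_support(-150., 150., 51)\n",
        "atom_spacing = support[1] - support[0]\n",
        "\n",
        "# Uniform logits put equal mass on every atom, so the mean is the midpoint.\n",
        "uniform_mean = expected_value([0.] * 51, support)\n",
        "```\n",
        "\n",
        "With these settings the atoms are spaced 6 units apart, so returns that fall outside the range of -150 to 150 cannot be represented exactly; the range should be chosen to cover the returns expected in the environment."
      ]
    },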
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4e_ytXw5tlkd"
      },
      "outputs": [],
      "source": [
        "# Create a logger for agent specific diagnostics.\n",
        "agent_logger = loggers.TerminalLogger(label='agent', time_delta=10)\n",
        "\n",
        "# Create the D4PG agent.\n",
        "agent = d4pg.D4PG(\n",
        "    environment_spec=environment_spec,\n",
        "    policy_network=policy_network,\n",
        "    critic_network=critic_network,\n",
        "    observation_network=observation_network,\n",
        "    logger=agent_logger,\n",
        "    checkpoint=False\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oKeGQxzitXYC"
      },
      "source": [
        "## Run a training loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "VWZd5N-Qoz82"
      },
      "outputs": [],
      "source": [
        "# Create a logger for agent specific diagnostics.\n",
        "env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10)\n",
        "\n",
        "env_loop = environment_loop.EnvironmentLoop(environment, agent, logger=env_loop_logger)\n",
        "env_loop.run(num_episodes=5)"
      ]
    },
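    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "For intuition about what `env_loop.run` does, the following dependency-free sketch mimics the interaction pattern between an environment and an actor: reset, select an action, step, observe, update. `ToyEnvironment`, `ToyActor`, and `run_episode` are hypothetical stand-ins written for this illustration and are not part of Acme; the real `EnvironmentLoop` also handles logging, among other things.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "class ToyEnvironment:\n",
        "  # Hypothetical 1-D walk that ends after a fixed number of steps.\n",
        "\n",
        "  def __init__(self, episode_length=10):\n",
        "    self._episode_length = episode_length\n",
        "    self._position = 0.0\n",
        "    self._steps = 0\n",
        "\n",
        "  def reset(self):\n",
        "    self._position = 0.0\n",
        "    self._steps = 0\n",
        "    return self._position\n",
        "\n",
        "  def step(self, action):\n",
        "    self._position += action\n",
        "    self._steps += 1\n",
        "    reward = -abs(self._position)  # Reward staying near the origin.\n",
        "    done = self._steps >= self._episode_length\n",
        "    return self._position, reward, done\n",
        "\n",
        "class ToyActor:\n",
        "  # Hypothetical actor that acts randomly and counts what it observes.\n",
        "\n",
        "  def __init__(self, seed=0):\n",
        "    self._rng = random.Random(seed)\n",
        "    self.num_observed = 0\n",
        "    self.num_updates = 0\n",
        "\n",
        "  def select_action(self, observation):\n",
        "    return self._rng.uniform(-1.0, 1.0)\n",
        "\n",
        "  def observe(self, action, reward, next_observation):\n",
        "    self.num_observed += 1  # A learning agent would store this transition.\n",
        "\n",
        "  def update(self):\n",
        "    self.num_updates += 1  # A learning agent would take a learner step.\n",
        "\n",
        "def run_episode(environment, actor):\n",
        "  # Run one episode and return the sum of rewards.\n",
        "  observation = environment.reset()\n",
        "  episode_return, done = 0.0, False\n",
        "  while not done:\n",
        "    action = actor.select_action(observation)\n",
        "    observation, reward, done = environment.step(action)\n",
        "    actor.observe(action, reward, observation)\n",
        "    actor.update()\n",
        "    episode_return += reward\n",
        "  return episode_return\n",
        "\n",
        "toy_env = ToyEnvironment(episode_length=10)\n",
        "toy_actor = ToyActor(seed=0)\n",
        "returns = [run_episode(toy_env, toy_actor) for _ in range(5)]\n",
        "```\n",
        "\n",
        "Running five episodes of ten steps each means the toy actor observes and updates 50 times, mirroring `num_episodes=5` in the cell above."
      ]
    },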
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Do57Ql4ZsWDu"
      },
      "source": [
        "## (Optional) Visualize an evaluation loop\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "L_3go1Pd1jwg"
      },
      "outputs": [],
      "source": [
        "# Install and import the necessary dependencies for visualization\n",
        "\n",
        "!sudo apt-get install -y xvfb ffmpeg\n",
        "!pip install 'gym==0.10.11'\n",
        "!pip install imageio\n",
        "!pip install PILLOW\n",
        "!pip install 'pyglet==1.3.2'\n",
        "!pip install pyvirtualdisplay\n",
        "\n",
        "import pyvirtualdisplay\n",
        "import imageio\n",
        "import base64\n",
        "\n",
        "# Set up a virtual display for rendering OpenAI gym environments.\n",
        "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()\n",
        "\n",
        "clear_output()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "OIJRbtAlxQVu"
      },
      "outputs": [],
      "source": [
        "def display_video(frames, filename='temp.mp4'):\n",
        "  \"\"\"Save and display video.\"\"\"\n",
        "  # Write video\n",
        "  with imageio.get_writer(filename, fps=60) as video:\n",
        "    for frame in frames:\n",
        "      video.append_data(frame)\n",
        "  # Read video and display the video\n",
        "  video = open(filename, 'rb').read()\n",
        "  b64_video = base64.b64encode(video)\n",
        "  video_tag = ('\u003cvideo  width=\"320\" height=\"240\" controls alt=\"test\" '\n",
        "               'src=\"data:video/mp4;base64,{0}\"\u003e').format(b64_video.decode())\n",
        "  return IPython.display.HTML(video_tag)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2__mFiraWND1"
      },
      "outputs": [],
      "source": [
        "# Run the actor in the environment for desired number of steps.\n",
        "frames = []\n",
        "num_steps = 100\n",
        "timestep = environment.reset()\n",
        "\n",
        "for _ in range(num_steps):\n",
        "  frames.append(render(environment))\n",
        "  action = agent.select_action(timestep.observation)\n",
        "  timestep = environment.step(action)\n",
        "\n",
        "# Save video of the behaviour.\n",
        "display_video(np.array(frames))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Acme: Quickstart",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
